{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3E96e1UKQ8uR"
      },
      "source": [
        "# MobileBERT QAT Tutorial\n",
        "\n",
        "This notebook provides a basic example code to build, run, and fine-tune [MobileBERT](https://arxiv.org/pdf/2004.02984.pdf) with QAT toolkit.\n",
        "\n",
        "Pretrained models downloaded from the [TensorFlow Hub](https://tfhub.dev/google/qat/nlp/mobilebert_xs_qat) and the [TensorFlow Model Garden](https://github.com/tensorflow/models/tree/master/official/projects/qat/nlp), which are both trained on [SQuAD](https://deepmind.com/research/open-source/kinetics) dateset for Q\u0026A task. You will run inference the models with dummy inputs."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8_oLnvJy7kz5"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s3khsunT7kWa"
      },
      "outputs": [],
      "source": [
        "# Install packages\n",
        "\n",
        "# tf-models-official is the stable Model Garden package\n",
        "# tf-models-nightly includes latest changes\n",
        "!pip install -q tf-models-nightly"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dI_1csl6Q-gH"
      },
      "outputs": [],
      "source": [
        "# Run imports\n",
        "import os\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hjFoFGXcv7cA"
      },
      "source": [
        "## Launch QAT Training\n",
        "\n",
        "Follow the [training guideline](https://github.com/tensorflow/models/tree/master/official/projects/qat/nlp#training) to start QAT training using the pretrained checkpoint."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tz68XJhdwzDk"
      },
      "source": [
        "## Running model from TFHub\n",
        "\n",
        "Running QAT trained MobileBERT model from tfhub. Note that it contains Fake-quant op and all ops are float32. It becomes actual int8 op when you convert them to TFLite using TFLite converter."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "cJilACgPw10a",
        "outputId": "a1d9ec5b-7a17-440f-e445-9fc868dcada4"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "(1, 384)\n",
            "(1, 384)\n"
          ]
        }
      ],
      "source": [
        "loaded_obj = hub.load(\"https://tfhub.dev/google/qat/nlp/mobilebert_xs_qat/1\")\n",
        "serving_model = loaded_obj.signatures['serving_default']\n",
        "\n",
        "# Dummy inputs\n",
        "input_type_ids = tf.zeros(shape=[1, 384], dtype=tf.int32)\n",
        "input_word_ids = tf.zeros(shape=[1, 384], dtype=tf.int32)\n",
        "input_mask = tf.zeros(shape=[1, 384], dtype=tf.int32)\n",
        "\n",
        "bert_inputs = dict(\n",
        "    input_type_ids=input_type_ids, input_word_ids=input_word_ids, input_mask=input_mask)\n",
        "\n",
        "bert_outputs = serving_model(**bert_inputs)\n",
        "\n",
        "start_logits = bert_outputs[\"start_logits\"]\n",
        "end_logits =  bert_outputs[\"end_logits\"]\n",
        "\n",
        "print(start_logits.shape)\n",
        "print(end_logits.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6g0tuFvf71S9"
      },
      "source": [
        "## Running TFLite Model Inference\n",
        "Running inference with trained quantized TFLite model with dummy dataset. We assume that data is already converted to integer from an input string using vocabulary."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QpsXxASSwfVZ",
        "outputId": "452d210a-67ea-4dfc-a8b1-f6136992ab99"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
            "                                 Dload  Upload   Total   Spent    Left  Speed\n",
            "100 33.8M  100 33.8M    0     0   102M      0 --:--:-- --:--:-- --:--:--  101M\n"
          ]
        }
      ],
      "source": [
        "# First download the TFLite model.\n",
        "! curl https://storage.googleapis.com/tf_model_garden/nlp/qat/mobilebert/model_qat.tflite --output model_qat.tflite"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Rw1OsQ_O1LJC"
      },
      "outputs": [],
      "source": [
        "def get_dequantized_tensor(interpreter, output_detail):\n",
        "  if ('quantization' not in output_detail or\n",
        "      np.dtype(output_detail['dtype']) == np.dtype(np.float32)):\n",
        "    return interpreter.get_tensor(output_detail['index'])\n",
        "  output_scale, output_zero_point = output_detail['quantization']\n",
        "  return (np.array(interpreter.get_tensor(output_detail['index']), dtype=np.float32) - output_zero_point) * output_scale\n",
        "\n",
        "def run_tflite(interpreter, input_word_ids, input_mask, input_type_ids):\n",
        "  input_word_ids_index, input_mask_index, input_type_ids_index = [\n",
        "      detail['index'] for detail in interpreter.get_input_details()]\n",
        "  interpreter.set_tensor(input_word_ids_index, input_word_ids)\n",
        "  interpreter.set_tensor(input_mask_index, input_mask)\n",
        "  interpreter.set_tensor(input_type_ids_index, input_type_ids)\n",
        "  interpreter.invoke()\n",
        "\n",
        "  start_logits_detail, end_logits_detail = interpreter.get_output_details()\n",
        "\n",
        "  return get_dequantized_tensor(interpreter, start_logits_detail), get_dequantized_tensor(interpreter, end_logits_detail)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cnN9RLht1kOA"
      },
      "outputs": [],
      "source": [
        "tflite_file = 'model_qat.tflite'\n",
        "with open(tflite_file, 'rb') as fp:\n",
        "  tflite_model = fp.read()\n",
        "\n",
        "interpreter = tf.lite.Interpreter(\n",
        "    model_content=tflite_model,\n",
        "    experimental_preserve_all_tensors=True)\n",
        "interpreter.allocate_tensors()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qzTgVDw51vGY",
        "outputId": "9e75bd54-6d2f-4fb6-c462-fe7fac854a27"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "(1, 384)\n",
            "(1, 384)\n"
          ]
        }
      ],
      "source": [
        "# Dummy inputs\n",
        "input_type_ids = np.zeros(shape=[1, 384], dtype=np.int32)\n",
        "input_word_ids = np.zeros(shape=[1, 384], dtype=np.int32)\n",
        "input_mask = np.zeros(shape=[1, 384], dtype=np.int32)\n",
        "\n",
        "start_logits, end_logits = run_tflite(interpreter, input_type_ids, input_word_ids, input_mask)\n",
        "\n",
        "print(start_logits.shape)\n",
        "print(end_logits.shape)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
